package vulnerability

import (
	"cmp"
	"slices"
	"strings"
	"sync"

	"github.com/samber/lo"

	"github.com/aquasecurity/trivy-db/pkg/db"
	dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
	"github.com/aquasecurity/trivy-db/pkg/vulnsrc/vulnerability"
	"github.com/aquasecurity/trivy/pkg/log"
	"github.com/aquasecurity/trivy/pkg/types"
	"github.com/aquasecurity/trivy/pkg/version/doc"
)

var (
	// primaryURLPrefixes lists, per data source, the URL prefixes that qualify a
	// reference as the primary URL of a vulnerability.
	//
	// The order of the prefixes is significant: getPrimaryURL tries them in the
	// listed order and returns the first reference that starts with the earliest
	// matching prefix.
	primaryURLPrefixes = map[dbTypes.SourceID][]string{
		vulnerability.Debian: {
			"http://www.debian.org",
			"https://www.debian.org",
		},
		vulnerability.Ubuntu: {
			"http://www.ubuntu.com",
			"https://usn.ubuntu.com",
		},
		vulnerability.RedHat: {"https://access.redhat.com"},
		vulnerability.SuseCVRF: {
			"http://lists.opensuse.org",
			"https://lists.opensuse.org",
		},
		vulnerability.OracleOVAL: {
			"http://linux.oracle.com/errata",
			"https://linux.oracle.com/errata",
		},
		vulnerability.NodejsSecurityWg: {
			"https://www.npmjs.com",
			"https://hackerone.com",
		},
		vulnerability.RubySec: {"https://groups.google.com"},
	}
)

// Show warning if we use severity from another vendor
// cf. https://github.com/aquasecurity/trivy/issues/6714
var onceWarn = sync.OnceFunc(func() {
	// e.g. https://trivy.dev/docs/latest/scanner/vulnerability/#severity-selection
	log.Warnf("Using severities from other vendors for some vulnerabilities. Read %s for details.", doc.URL("guide/scanner/vulnerability/", "severity-selection"))
})

// Client manipulates vulnerabilities: it fills detected vulnerabilities with
// the details it looks up through the wrapped DB accessor.
type Client struct {
	// dbc is the vulnerability DB accessor; FillInfo calls its GetVulnerability method.
	dbc db.Operation
}

// NewClient returns a Client that looks vulnerabilities up through dbc.
func NewClient(dbc db.Operation) Client {
	var client Client
	client.dbc = dbc
	return client
}

// FillInfo enriches every entry of vulns in place with its status, the details
// loaded from the vulnerability DB, a severity plus the source that severity
// came from, and a primary URL.
//
// severitySources is handed to getSeverity, which tries the sources in the
// given order. An entry whose details cannot be loaded only gets its status
// updated; a warning is logged and the entry is otherwise left untouched.
//
// After all entries are processed, one warning lists the IDs for which
// getSeverity found no severity, unless severitySources contains "auto".
func (c Client) FillInfo(vulns []types.DetectedVulnerability, severitySources []dbTypes.SourceID) {
	// IDs for which getSeverity came back with the "unknown" severity.
	var unresolved []string

	for i := range vulns {
		dv := &vulns[i]

		// Status: a fixed version always means "fixed". Without one, a status
		// supplied by the vendor (e.g. Red Hat) is kept, and only an unknown
		// status is turned into "affected".
		switch {
		case dv.FixedVersion != "":
			dv.Status = dbTypes.StatusFixed
		case dv.Status == dbTypes.StatusUnknown:
			dv.Status = dbTypes.StatusAffected
		}

		// Load the vulnerability detail; skip the rest of the enrichment on failure.
		id := dv.VulnerabilityID
		detail, err := c.dbc.GetVulnerability(id)
		if err != nil {
			log.Warn("Unable to get vulnerability details (CVE may be rejected)", log.Err(err))
			continue
		}

		// A nil DataSource is dereferenced to its zero value.
		src := lo.FromPtr(dv.DataSource)

		// For the severity lookup, BaseID takes precedence over ID when it is set,
		// e.g. `debian` severity for root.io advisories.
		severity, severitySource := c.getSeverity(id, &detail, cmp.Or(src.BaseID, src.ID), severitySources)
		if severity == dbTypes.SeverityUnknown.String() {
			unresolved = append(unresolved, id)
		}

		// A vendor may rate the same vulnerability differently per package.
		// For example, CVE-2015-2328 in Debian has "unimportant" for mongodb and "low" for pcre3.
		// A severity source already present on the entry therefore wins over the lookup above.
		if pkgSource := dv.SeveritySource; pkgSource != "" {
			severity, severitySource = dv.Severity, pkgSource

			// Record the package-specific severity among the vendor severities.
			if detail.VendorSeverity == nil {
				detail.VendorSeverity = make(dbTypes.VendorSeverity)
			}
			// The error is dropped on purpose: `SeverityUnknown` will be returned in case of error.
			parsed, _ := dbTypes.NewSeverity(severity)
			detail.VendorSeverity[pkgSource] = parsed
		}

		// NOTE(review): keep the detail assignment ahead of Severity — if Severity
		// is promoted from the embedded Vulnerability, the other order would
		// overwrite the chosen severity. Confirm against the type definition.
		dv.Vulnerability = detail
		dv.Severity = severity
		dv.SeveritySource = severitySource
		dv.PrimaryURL = c.getPrimaryURL(id, detail.References, src.ID)
	}

	if len(unresolved) == 0 || slices.Contains(severitySources, "auto") {
		return
	}
	log.Warn("No severity found in specified sources",
		log.Any("vulnerability-ids", unresolved), log.Any("severity-sources", severitySources))
}

// getSeverity returns the severity of vuln together with the source it was
// taken from, trying severitySources in the given order.
//
// The special source "auto" ends the search immediately by delegating to
// autoDetectSeverity, so sources listed after it are never consulted. Any other
// source is used when vuln.VendorSeverity has an entry for it. When no source
// yields a severity, the "unknown" severity and an empty source are returned.
func (c Client) getSeverity(vulnID string, vuln *dbTypes.Vulnerability, dataSourceID dbTypes.SourceID, severitySources []dbTypes.SourceID) (string, dbTypes.SourceID) {
	for idx := range severitySources {
		candidate := severitySources[idx]
		if candidate == "auto" {
			return c.autoDetectSeverity(vulnID, vuln, dataSourceID)
		}

		rated, found := vuln.VendorSeverity[candidate]
		if !found {
			continue
		}
		return rated.String(), candidate
	}

	return dbTypes.SeverityUnknown.String(), ""
}

// autoDetectSeverity picks a severity for vuln without an explicit source list.
//
// The lookup proceeds as follows:
//  1. If vuln.VendorSeverity has an entry for dataSourceID, that entry is
//     returned as is, whatever its value.
//  2. Otherwise getSeverity is asked with the data source, GHSA (only for
//     IDs starting with "GHSA-") and NVD, in that order.
//  3. If that yields nothing but vuln.Severity is set, vuln.Severity is
//     returned with an empty source and the one-time warning is triggered.
//  4. Otherwise the "unknown" severity is returned with an empty source.
func (c Client) autoDetectSeverity(vulnID string, vuln *dbTypes.Vulnerability, dataSourceID dbTypes.SourceID) (string, dbTypes.SourceID) {
	if own, found := vuln.VendorSeverity[dataSourceID]; found {
		return own.String(), dataSourceID
	}

	// GHSA is consulted only for GHSA-IDs and sits between the data source and NVD.
	candidates := []dbTypes.SourceID{dataSourceID}
	if strings.HasPrefix(vulnID, "GHSA-") {
		candidates = append(candidates, vulnerability.GHSA)
	}
	candidates = append(candidates, vulnerability.NVD)

	severity, severitySource := c.getSeverity(vulnID, vuln, dataSourceID, candidates)
	switch {
	case severity != dbTypes.SeverityUnknown.String():
		return severity, severitySource
	case vuln.Severity == "":
		return dbTypes.SeverityUnknown.String(), ""
	}

	// Falling back to vuln.Severity: tell the user once.
	onceWarn()
	return vuln.Severity, ""
}

// getPrimaryURL returns the URL that represents the vulnerability.
//
// IDs with a known prefix are mapped to a fixed site: "CVE-" to
// avd.aquasec.com (with the ID lower-cased), "RUSTSEC-" to osv.dev, "GHSA-" to
// github.com/advisories and "TEMP-" to the Debian security tracker.
//
// For any other ID, the prefixes registered for source in primaryURLPrefixes
// are tried in order, and the first entry of refs starting with the earliest
// matching prefix is returned. An empty string means no URL was found.
func (c Client) getPrimaryURL(vulnID string, refs []string, source dbTypes.SourceID) string {
	if strings.HasPrefix(vulnID, "CVE-") {
		return "https://avd.aquasec.com/nvd/" + strings.ToLower(vulnID)
	}

	// Sites whose URL is simply the base followed by the unmodified ID.
	idSites := [...]struct{ idPrefix, baseURL string }{
		{"RUSTSEC-", "https://osv.dev/vulnerability/"},
		{"GHSA-", "https://github.com/advisories/"},
		{"TEMP-", "https://security-tracker.debian.org/tracker/"},
	}
	for _, site := range idSites {
		if strings.HasPrefix(vulnID, site.idPrefix) {
			return site.baseURL + vulnID
		}
	}

	// Prefix-major search: an earlier prefix beats an earlier reference.
	for _, prefix := range primaryURLPrefixes[source] {
		match := slices.IndexFunc(refs, func(ref string) bool {
			return strings.HasPrefix(ref, prefix)
		})
		if match >= 0 {
			return refs[match]
		}
	}
	return ""
}
